"""Example of hierarchical training using the multi-agent API.

The example env is that of a "windy maze". The agent observes the current wind
direction and can choose either to stand still or to move in that direction.

You can try out the env directly with:

    $ python hierarchical_training.py --flat

A simple hierarchical formulation involves a high-level agent that issues goals
(i.e., go north / south / east / west), and a low-level agent that executes
these goals over a number of timesteps. This can be implemented as a
multi-agent environment with a top-level agent and a new low-level agent
spawned for each high-level action. The low-level agent is rewarded for moving
in the direction of its goal.
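
At every step of the hierarchical env exactly one agent acts: either the
high-level agent, which emits a goal, or the currently active low-level
agent. The multi-agent dicts are keyed by agent id; for example, the
observation returned right after the third high-level step looks like
{"low_level_3": [flat_obs, goal]}.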

You can try this formulation with:

    $ python hierarchical_training.py  # gets ~100 rew after ~100k timesteps

Note that the hierarchical formulation actually converges slightly more slowly
than using --flat in this example.
"""

import argparse
import logging
import random

import gym
from gym.spaces import Box, Discrete, Tuple

import ray
from ray import tune
from ray.tune import function
from ray.rllib.env import MultiAgentEnv

parser = argparse.ArgumentParser()
parser.add_argument("--flat", action="store_true")

# The agent has to traverse the maze from the start position S to the
# finish position F.
# Observation space: ((x_pos, y_pos), wind_direction)
# Action space: stand still OR move in the current wind direction
MAP_DATA = """
#########
#S      #
####### #
      # #
      # #
####### #
#F      #
#########"""

logger = logging.getLogger(__name__)


class WindyMazeEnv(gym.Env):
    def __init__(self, env_config):
        self.map = [m for m in MAP_DATA.split("\n") if m]
        self.x_dim = len(self.map)
        self.y_dim = len(self.map[0])
        logger.info("Loaded map {} {}".format(self.x_dim, self.y_dim))
        for x in range(self.x_dim):
            for y in range(self.y_dim):
                if self.map[x][y] == "S":
                    self.start_pos = (x, y)
                elif self.map[x][y] == "F":
                    self.end_pos = (x, y)
        logger.info("Start pos {} end pos {}".format(self.start_pos,
                                                     self.end_pos))
        self.observation_space = Tuple([
            Box(0, 100, shape=(2, )),  # (x, y)
            Discrete(4),  # wind direction (N, E, S, W)
        ])
        self.action_space = Discrete(2)  # whether to move or not

    def reset(self):
        self.wind_direction = random.choice([0, 1, 2, 3])
        self.pos = self.start_pos
        self.num_steps = 0
        return [[self.pos[0], self.pos[1]], self.wind_direction]

    def step(self, action):
        if action == 1:
            self.pos = self._get_new_pos(self.pos, self.wind_direction)
        self.num_steps += 1
        self.wind_direction = random.choice([0, 1, 2, 3])
        at_goal = self.pos == self.end_pos
        done = at_goal or self.num_steps >= 200
        return ([[self.pos[0], self.pos[1]], self.wind_direction],
                100 * int(at_goal), done, {})

    def _get_new_pos(self, pos, direction):
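        # Direction encoding: 0 = N (up a row), 1 = E, 2 = S, 3 = W.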
        if direction == 0:
            new_pos = (pos[0] - 1, pos[1])
        elif direction == 1:
            new_pos = (pos[0], pos[1] + 1)
        elif direction == 2:
            new_pos = (pos[0] + 1, pos[1])
        elif direction == 3:
            new_pos = (pos[0], pos[1] - 1)
        if (new_pos[0] >= 0 and new_pos[0] < self.x_dim and new_pos[1] >= 0
                and new_pos[1] < self.y_dim
                and self.map[new_pos[0]][new_pos[1]] != "#"):
            return new_pos
        else:
            return pos  # did not move


class HierarchicalWindyMazeEnv(MultiAgentEnv):
    def __init__(self, env_config):
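        # Wrap the flat env; the hierarchical env drives it from
        # _low_level_step() rather than exposing its step() directly.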
        self.flat_env = WindyMazeEnv(env_config)

    def reset(self):
        self.cur_obs = self.flat_env.reset()
        self.current_goal = None
        self.steps_remaining_at_level = None
        self.num_high_level_steps = 0
        # Current low-level agent id. This must be unique for each high-level
        # step since agent ids cannot be reused.
        self.low_level_agent_id = "low_level_{}".format(
            self.num_high_level_steps)
        return {
            "high_level_agent": self.cur_obs,
        }

    def step(self, action_dict):
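        # Exactly one agent acts per call: the high-level agent right after
        # reset() or when a low-level budget expires; otherwise the currently
        # active low-level agent.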
        assert len(action_dict) == 1, action_dict
        if "high_level_agent" in action_dict:
            return self._high_level_step(action_dict["high_level_agent"])
        else:
            return self._low_level_step(list(action_dict.values())[0])

    def _high_level_step(self, action):
        logger.debug("High level agent sets goal".format(action))
        self.current_goal = action
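        # Give the low-level agent a fixed budget of 25 env steps for this
        # goal, and mint a fresh, unique agent id for the new low-level
        # episode.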
        self.steps_remaining_at_level = 25
        self.num_high_level_steps += 1
        self.low_level_agent_id = "low_level_{}".format(
            self.num_high_level_steps)
        obs = {self.low_level_agent_id: [self.cur_obs, self.current_goal]}
        rew = {self.low_level_agent_id: 0}
        done = {"__all__": False}
        return obs, rew, done, {}

    def _low_level_step(self, action):
        logger.debug("Low level agent step {}".format(action))
        self.steps_remaining_at_level -= 1
        cur_pos = tuple(self.cur_obs[0])
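        # goal_pos is the cell one step from the current position in the goal
        # direction (or the current cell itself if that step is blocked).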
        goal_pos = self.flat_env._get_new_pos(cur_pos, self.current_goal)

        # Step in the actual env
        f_obs, f_rew, f_done, _ = self.flat_env.step(action)
        new_pos = tuple(f_obs[0])
        self.cur_obs = f_obs

        # Calculate low-level agent observation and reward
        obs = {self.low_level_agent_id: [f_obs, self.current_goal]}
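        # Reward shaping: +1 for stepping into the goal cell, -1 for moving
        # anywhere else, 0 for standing still.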
        if new_pos != cur_pos:
            if new_pos == goal_pos:
                rew = {self.low_level_agent_id: 1}
            else:
                rew = {self.low_level_agent_id: -1}
        else:
            rew = {self.low_level_agent_id: 0}

        # Handle env termination & transitions back to higher level
        done = {"__all__": False}
        if f_done:
            done["__all__"] = True
            logger.debug("high level final reward {}".format(f_rew))
            rew["high_level_agent"] = f_rew
            obs["high_level_agent"] = f_obs
        elif self.steps_remaining_at_level == 0:
            done[self.low_level_agent_id] = True
            rew["high_level_agent"] = 0
            obs["high_level_agent"] = f_obs

        return obs, rew, done, {}


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()
    if args.flat:
        tune.run(
            "PPO",
            config={
                "env": WindyMazeEnv,
                "num_workers": 0,
            },
        )
    else:
        maze = WindyMazeEnv(None)

        def policy_mapping_fn(agent_id):
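            # All spawned low_level_* agents share a single low-level policy;
            # the lone high_level_agent gets its own.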
            if agent_id.startswith("low_level_"):
                return "low_level_policy"
            else:
                return "high_level_policy"

        tune.run(
            "PPO",
            config={
                "env": HierarchicalWindyMazeEnv,
                "num_workers": 0,
                "log_level": "INFO",
                "entropy_coeff": 0.01,
                "multiagent": {
                    "policies": {
                        "high_level_policy": (None, maze.observation_space,
                                              Discrete(4), {
                                                  "gamma": 0.9
                                              }),
                        "low_level_policy": (None,
                                             Tuple([
                                                 maze.observation_space,
                                                 Discrete(4)
                                             ]), maze.action_space, {
                                                 "gamma": 0.0
                                             }),
                    },
                    "policy_mapping_fn": function(policy_mapping_fn),
                },
            },
        )
